"""
PyTorch implementation of SegNet (https://arxiv.org/pdf/1511.00561.pdf)
"""

from __future__ import print_function
from collections import OrderedDict
import torch
import torch.nn as nn
import torchvision.models as models

# Short alias for the functional API; SegNet.forward uses it for relu,
# max_pool2d and max_unpool2d.
F = nn.functional
# When True, SegNet.forward prints the tensor sizes it records at every stage.
DEBUG = False

# Channel widths per encoder stage. The numbers match the out_channels of the
# encoder convolutions built in SegNet.__init__.
# NOTE(review): "M" presumably marks the max-pool that ends each stage, and
# this table is not referenced by the code visible in this file -- confirm
# before relying on it.
vgg16_dims = [
    (64, 64, "M"),  # Stage - 1
    (128, 128, "M"),  # Stage - 2
    (256, 256, 256, "M"),  # Stage - 3
    (512, 512, 512, "M"),  # Stage - 4
    (512, 512, 512, "M"),  # Stage - 5
]

# Channel widths per decoder stage, listed from the deepest stage outwards.
# The numbers match the in_channels of the decoder layers built in
# SegNet.__init__ (the out_channels drop at the end of stages 4, 3 and 2).
# NOTE(review): "U" presumably marks the max-unpool that starts each stage,
# and this table is not referenced by the code visible in this file -- confirm.
decoder_dims = [
    ("U", 512, 512, 512),  # Stage - 5
    ("U", 512, 512, 512),  # Stage - 4
    ("U", 256, 256, 256),  # Stage - 3
    ("U", 128, 128),  # Stage - 2
    ("U", 64, 64),  # Stage - 1
]


class SegNet(nn.Module):
    """
    SegNet encoder-decoder network.

    The encoder is thirteen 3x3 convolution + batch-norm layers arranged in
    five stages; every stage ends with a 2x2 max-pool whose argmax indices are
    kept. The decoder walks the stages in reverse: it max-unpools with the
    stored indices and then applies 3x3 transposed convolutions + batch-norm.
    The last decoder layer has no batch-norm and emits `num_classes` channels,
    which `forward` squashes with a sigmoid.

    `in_channels` is stored on the instance but is not used to size any layer:
    the first convolution is always built for 3 input channels.
    """

    net_name = "segnet"

    def __init__(self, in_channels, num_classes):
        """
        Build all encoder and decoder layers.

        :param in_channels: recorded as `self.in_channels`; it does not change
            the first convolution, which is built with 3 input channels
        :param num_classes: number of output channels of the last decoder layer
        """
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Layers are created in the same order and under the same attribute
        # names as always, so state_dict keys and seeded initialisation are
        # unaffected by how they are constructed.

        # Encoder: stage 1
        self.encoder_conv_00 = self._conv_bn(3, 64)
        self.encoder_conv_01 = self._conv_bn(64, 64)
        # Encoder: stage 2
        self.encoder_conv_10 = self._conv_bn(64, 128)
        self.encoder_conv_11 = self._conv_bn(128, 128)
        # Encoder: stage 3
        self.encoder_conv_20 = self._conv_bn(128, 256)
        self.encoder_conv_21 = self._conv_bn(256, 256)
        self.encoder_conv_22 = self._conv_bn(256, 256)
        # Encoder: stage 4
        self.encoder_conv_30 = self._conv_bn(256, 512)
        self.encoder_conv_31 = self._conv_bn(512, 512)
        self.encoder_conv_32 = self._conv_bn(512, 512)
        # Encoder: stage 5
        self.encoder_conv_40 = self._conv_bn(512, 512)
        self.encoder_conv_41 = self._conv_bn(512, 512)
        self.encoder_conv_42 = self._conv_bn(512, 512)

        # Decoder: stage 5
        self.decoder_convtr_42 = self._convtr_bn(512, 512)
        self.decoder_convtr_41 = self._convtr_bn(512, 512)
        self.decoder_convtr_40 = self._convtr_bn(512, 512)
        # Decoder: stage 4
        self.decoder_convtr_32 = self._convtr_bn(512, 512)
        self.decoder_convtr_31 = self._convtr_bn(512, 512)
        self.decoder_convtr_30 = self._convtr_bn(512, 256)
        # Decoder: stage 3
        self.decoder_convtr_22 = self._convtr_bn(256, 256)
        self.decoder_convtr_21 = self._convtr_bn(256, 256)
        self.decoder_convtr_20 = self._convtr_bn(256, 128)
        # Decoder: stage 2
        self.decoder_convtr_11 = self._convtr_bn(128, 128)
        self.decoder_convtr_10 = self._convtr_bn(128, 64)
        # Decoder: stage 1 -- the output layer carries no batch-norm.
        self.decoder_convtr_01 = self._convtr_bn(64, 64)
        self.decoder_convtr_00 = nn.Sequential(
            nn.ConvTranspose2d(64, self.num_classes, kernel_size=3, padding=1)
        )

    @staticmethod
    def _conv_bn(channels_in, channels_out):
        """Return a 3x3, padding-1 convolution followed by batch-norm."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        )

    @staticmethod
    def _convtr_bn(channels_in, channels_out):
        """Return a 3x3, padding-1 transposed convolution followed by batch-norm."""
        return nn.Sequential(
            nn.ConvTranspose2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        )

    def forward(self, input_img):
        """
        Run `input_img` through the encoder and decoder.

        If dimension 1 of `input_img` is not of size 3, the tensor is
        concatenated with itself three times along that dimension first, so
        only a 1-channel input ends up with the 3 channels the first layer is
        built for.

        :param input_img: input tensor; dimension 1 is treated as channels
        :return: sigmoid of the last decoder layer's output
        """
        if input_img.shape[1] != 3:
            input_img = torch.cat([input_img] * 3, dim=1)

        encoder_stages = (
            (self.encoder_conv_00, self.encoder_conv_01),
            (self.encoder_conv_10, self.encoder_conv_11),
            (self.encoder_conv_20, self.encoder_conv_21, self.encoder_conv_22),
            (self.encoder_conv_30, self.encoder_conv_31, self.encoder_conv_32),
            (self.encoder_conv_40, self.encoder_conv_41, self.encoder_conv_42),
        )
        decoder_stages = (
            (self.decoder_convtr_42, self.decoder_convtr_41, self.decoder_convtr_40),
            (self.decoder_convtr_32, self.decoder_convtr_31, self.decoder_convtr_30),
            (self.decoder_convtr_22, self.decoder_convtr_21, self.decoder_convtr_20),
            (self.decoder_convtr_11, self.decoder_convtr_10),
            (self.decoder_convtr_01, self.decoder_convtr_00),
        )
        output_layer = self.decoder_convtr_00

        # Encoder: remember each stage's input size and pooling indices so the
        # decoder can undo the pooling step by step.
        features = input_img
        stage_input_sizes = []
        pooling_indices = []
        for stage in encoder_stages:
            stage_input_sizes.append(features.size())
            for layer in stage:
                features = F.relu(layer(features))
            features, indices = F.max_pool2d(
                features, kernel_size=2, stride=2, return_indices=True
            )
            pooling_indices.append(indices)
        bottleneck_size = features.size()

        # Decoder: consume the recorded sizes/indices from the deepest stage
        # back out to the first one.
        stage_output_sizes = []
        for stage, indices, target_size in zip(
            decoder_stages, reversed(pooling_indices), reversed(stage_input_sizes)
        ):
            features = F.max_unpool2d(
                features, indices, kernel_size=2, stride=2, output_size=target_size
            )
            for layer in stage:
                features = layer(features)
                # Every decoder layer is ReLU-activated except the output layer.
                if layer is not output_layer:
                    features = F.relu(features)
            stage_output_sizes.append(features.size())

        if DEBUG:
            labels = (
                "dim_0: {}",
                "dim_1: {}",
                "dim_2: {}",
                "dim_3: {}",
                "dim_4: {}",
                "dim_d: {}",
                "dim_4d: {}",
                "dim_3d: {}",
                "dim_2d: {}",
                "dim_1d: {}",
                "dim_0d: {}",
            )
            sizes = stage_input_sizes + [bottleneck_size] + stage_output_sizes
            for label, size in zip(labels, sizes):
                print(label.format(size))

        return torch.sigmoid(features)
